{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Welcome to Trident: A Python Package for Whole-Slide Image Processing\n",
    "\n",
    "\n",
    "This tutorial guides you step by step through processing a single whole-slide image (WSI) with Trident:\n",
    "\n",
    "- Tissue vs. background segmentation\n",
    "- Tissue coordinate extraction\n",
    "- Patch feature extraction\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 0- Installation\n",
    "\n",
    "\n",
    "```bash\n",
    "conda create -n \"trident\" python=3.10\n",
    "conda activate trident\n",
    "git clone git@github.com:mahmoodlab/trident.git && cd trident\n",
    "pip install -e .\n",
    "```\n",
    "\n",
    "Please refer to the FAQ if you face installation issues."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1- Tissue vs background segmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from PIL import Image\n",
    "import geopandas as gpd\n",
    "from IPython.display import display\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "from trident import OpenSlideWSI\n",
    "from trident.segmentation_models import segmentation_model_factory\n",
    "\n",
    "# a. Download a WSI\n",
    "OUTPUT_DIR = \"tutorial-1/\"\n",
    "# Use the first GPU when one is visible to torch, otherwise fall back to the CPU.\n",
    "DEVICE = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "WSI_FNAME = '394140.svs'\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "# allow_patterns limits the snapshot to the single slide named by WSI_FNAME.\n",
    "local_wsi_dir = snapshot_download(\n",
    "    repo_id=\"MahmoodLab/unit-testing\",\n",
    "    repo_type='dataset',\n",
    "    local_dir=os.path.join(OUTPUT_DIR, 'wsis'),\n",
    "    allow_patterns=[WSI_FNAME]\n",
    ")\n",
    "\n",
    "# b. Create OpenSlideWSI\n",
    "wsi_path = os.path.join(local_wsi_dir, WSI_FNAME)\n",
    "slide = OpenSlideWSI(slide_path=wsi_path, lazy_init=False)\n",
    "\n",
    "# c. Run segmentation\n",
    "segmentation_model = segmentation_model_factory(\"hest\")\n",
    "geojson_contours = slide.segment_tissue(\n",
    "    segmentation_model=segmentation_model,\n",
    "    target_mag=10,\n",
    "    job_dir=OUTPUT_DIR,\n",
    "    device=DEVICE\n",
    ")\n",
    "\n",
    "# d. Visualize contours\n",
    "# The overlay is read from OUTPUT_DIR/contours/<slide stem>.jpg. splitext strips whatever\n",
    "# extension WSI_FNAME carries, whereas replace('.svs', ...) only handled .svs files.\n",
    "slide_stem = os.path.splitext(WSI_FNAME)[0]\n",
    "contour_image = Image.open(os.path.join(OUTPUT_DIR, 'contours', slide_stem + '.jpg'))\n",
    "display(contour_image)\n",
    "\n",
    "# e. Check contours saved into GeoJSON with GeoPandas\n",
    "gdf = gpd.read_file(geojson_contours)\n",
    "gdf.head(n=10)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2- Tissue coordinate extraction\n",
    "\n",
    "The whole-slide image is tiled into non-overlapping 256x256-pixel patches at 20x magnification (0.5 µm/px)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "\n",
    "TARGET_MAG = 20\n",
    "PATCH_SIZE = 256\n",
    "\n",
    "# a. Run patch coordinate extraction\n",
    "coords_path = slide.extract_tissue_coords(\n",
    "    target_mag=TARGET_MAG, patch_size=PATCH_SIZE, save_coords=OUTPUT_DIR\n",
    ")\n",
    "\n",
    "# b. Visualize\n",
    "viz_dir = os.path.join(OUTPUT_DIR, \"visualization\")\n",
    "viz_coords_path = slide.visualize_coords(coords_path=coords_path, save_patch_viz=viz_dir)\n",
    "display(Image.open(viz_coords_path))\n",
    "\n",
    "# c. Inspect h5 with patch coordinates\n",
    "def print_attrs(name, obj):\n",
    "    \"\"\"Print an HDF5 object's name followed by each of its attributes.\n",
    "\n",
    "    Takes the (name, obj) pair that h5py's visititems() hands to its callback and\n",
    "    implicitly returns None, so the traversal continues over every object.\n",
    "    \"\"\"\n",
    "    print(f\"Object: {name}\")\n",
    "    for attr_name, attr_value in obj.attrs.items():\n",
    "        print(f\"  Attribute - {attr_name}: {attr_value}\")\n",
    "\n",
    "with h5py.File(coords_path, \"r\") as coords_h5:\n",
    "    print(\"Contents and Attributes in patch coords file:\")\n",
    "    coords_h5.visititems(print_attrs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3- Patch feature extraction with the UNI model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trident.patch_encoder_models import encoder_factory\n",
    "\n",
    "# Visit the factory or check the README for a list of all available models\n",
    "PATCH_ENCODER = \"uni_v1\"\n",
    "\n",
    "# a. Instantiate UNI model using the factory\n",
    "encoder = encoder_factory(PATCH_ENCODER)\n",
    "encoder.eval()\n",
    "encoder.to(DEVICE)\n",
    "\n",
    "# b. Run UNI feature extraction\n",
    "features_dir = os.path.join(OUTPUT_DIR, \"features_\" + PATCH_ENCODER)\n",
    "feats_path = slide.extract_patch_features(\n",
    "    patch_encoder=encoder, coords_path=coords_path, save_features=features_dir, device=DEVICE\n",
    ")\n",
    "\n",
    "# c. Inspect h5 with patch features (print_attrs is defined in the coordinate-extraction cell)\n",
    "with h5py.File(feats_path, \"r\") as feats_h5:\n",
    "    print(\"Contents and Attributes in feats file:\")\n",
    "    feats_h5.visititems(print_attrs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "trident",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
